{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2080947c-96d9-447f-8368-cfdc9e5c9960",
   "metadata": {},
   "source": [
    "# Using Semantic Chunks with the Gemini API and Gemini Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53221f1a-a0c1-4506-a3d0-d6626c58e4e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular Imports\n",
    "import os\n",
    "import glob\n",
    "import time\n",
    "from dotenv import load_dotenv\n",
    "from tqdm.notebook import tqdm\n",
    "import gradio as gr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a2a7171-a7b6-42a6-96d7-c93f360689ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visual Import\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.manifold import TSNE\n",
    "import numpy as np\n",
    "import plotly.graph_objects as go"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51c9d658-65e5-40a1-8680-d0b561f87649",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lang Chain Imports\n",
    "\n",
    "from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI\n",
    "from langchain_community.document_loaders import DirectoryLoader, TextLoader\n",
    "from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate\n",
    "from langchain_core.messages import HumanMessage, AIMessage\n",
    "from langchain_chroma import Chroma\n",
    "from langchain_experimental.text_splitter import SemanticChunker\n",
    "from langchain_core.chat_history import InMemoryChatMessageHistory\n",
    "from langchain_core.runnables.history import RunnableWithMessageHistory\n",
    "from langchain.chains.combine_documents import create_stuff_documents_chain\n",
    "from langchain.chains.history_aware_retriever import create_history_aware_retriever\n",
    "from langchain.chains import create_retrieval_chain\n",
    "from langchain_core.prompts import MessagesPlaceholder\n",
    "from langchain_core.runnables import RunnableLambda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e7ed82b-b28a-4094-9f77-3b6432dd0f7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constants\n",
    "\n",
    "CHAT_MODEL = \"gemini-2.5-flash\"\n",
    "EMBEDDING_MODEL = \"models/text-embedding-004\"\n",
    "# EMBEDDING_MODEL_EXP = \"models/gemini-embedding-exp-03-07\"\n",
    "\n",
    "folders = glob.glob(\"knowledge-base/*\")\n",
    "text_loader_kwargs = {'encoding': 'utf-8'}\n",
    "db_name = \"vector_db\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b83281a2-bcae-41ab-a347-0e7f9688d1ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull variables from a local .env file, overriding anything already set.\n",
    "load_dotenv(override=True)\n",
    "\n",
    "api_key = os.getenv(\"GOOGLE_API_KEY\")\n",
    "\n",
    "# Report only whether the key is present; the value itself is never printed.\n",
    "print(\"API Key loaded in memory\" if api_key else \"API Key not found!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fd6d516-772b-478d-9b28-09d42f2277d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_metadata(doc, doc_type):\n",
    "    doc.metadata[\"doc_type\"] = doc_type\n",
    "    return doc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bc4198b-f989-42c0-95b5-3596448fcaa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every markdown file, using one DirectoryLoader per knowledge-base folder.\n",
    "documents = []\n",
    "for folder in tqdm(folders, desc=\"Loading folders\"):\n",
    "    # The folder name doubles as the document category.\n",
    "    doc_type = os.path.basename(folder)\n",
    "    loader = DirectoryLoader(\n",
    "        folder,\n",
    "        glob=\"**/*.md\",\n",
    "        loader_cls=TextLoader,\n",
    "        loader_kwargs=text_loader_kwargs,\n",
    "    )\n",
    "    for loaded_doc in loader.load():\n",
    "        documents.append(add_metadata(loaded_doc, doc_type))\n",
    "\n",
    "print(f\"Total documents loaded: {len(documents)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb74241f-e9d5-42e8-9a4b-f31018397d66",
   "metadata": {},
   "source": [
    "## Create Semantic Chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a3aa17f-f5d0-430a-80da-95c284bd99a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embedding model the splitter uses to compare neighbouring pieces of text.\n",
    "chunking_embedding_model = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, task_type=\"retrieval_document\")\n",
    "\n",
    "# Percentile-based breakpoints with a threshold of 95.0.\n",
    "text_splitter = SemanticChunker(\n",
    "    chunking_embedding_model,\n",
    "    breakpoint_threshold_type=\"percentile\",\n",
    "    breakpoint_threshold_amount=95.0,\n",
    "    min_chunk_size=3\n",
    ")\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "semantic_chunks = []\n",
    "failed_sources = []  # sources of documents the splitter could not process\n",
    "pbar = tqdm(documents, desc=\"Semantic chunking documents\")\n",
    "\n",
    "# Split one document at a time so a single failure does not abort the whole run.\n",
    "for doc in pbar:\n",
    "    pbar.set_postfix_str(f\"Processing: {doc.metadata.get('doc_type', 'Unknown')}\")\n",
    "    try:\n",
    "        semantic_chunks.extend(text_splitter.split_documents([doc]))\n",
    "    except Exception as e:\n",
    "        source = doc.metadata.get('source', 'unknown source')\n",
    "        failed_sources.append(source)\n",
    "        tqdm.write(f\"❌ Failed to split doc ({source}): {e}\")\n",
    "\n",
    "print(f\"⏱️ Took {time.time() - start:.2f} seconds\")\n",
    "print(f\"Total semantic chunks: {len(semantic_chunks)}\")\n",
    "# Make skipped documents visible in the summary, not only in the scrolling log.\n",
    "if failed_sources:\n",
    "    print(f\"⚠️ {len(failed_sources)} document(s) could not be chunked and were skipped\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "675b98d6-5ed0-45d1-8f79-765911e6badf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the first 15 chunks\n",
    "for chunk_number, chunk in enumerate(semantic_chunks[:15], start=1):\n",
    "    print(f\"--- Chunk {chunk_number} ---\")\n",
    "    print(chunk.page_content)\n",
    "    print(\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c17accff-539a-490b-8a5f-b5ce632a3c71",
   "metadata": {},
   "source": [
    "## Embed with Gemini Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bd228bd-37d2-4aaf-b0f6-d94943f6f248",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embedding model handed to Chroma for the vector store.\n",
    "embedding = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, task_type=\"retrieval_document\")\n",
    "\n",
    "# Start from a clean slate: drop the collection left behind by a previous run.\n",
    "if os.path.exists(db_name):\n",
    "    stale_store = Chroma(persist_directory=db_name, embedding_function=embedding)\n",
    "    stale_store.delete_collection()\n",
    "\n",
    "# Embed the semantic chunks and persist them under db_name.\n",
    "vectorstore = Chroma.from_documents(\n",
    "    documents=semantic_chunks,\n",
    "    embedding=embedding,\n",
    "    persist_directory=db_name,\n",
    ")\n",
    "\n",
    "stored_count = vectorstore._collection.count()\n",
    "print(f\"✅ Vectorstore created with {stored_count} documents\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce0a3e23-5912-4de2-bf34-3c0936375de1",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Visualizing Vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ffdc6f5-ec25-4229-94d4-1fc6bb4d2702",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the stored vectors, texts and metadata back out of Chroma for plotting.\n",
    "collection = vectorstore._collection\n",
    "result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n",
    "vectors = np.array(result['embeddings'])\n",
    "# NOTE(review): this rebinds `documents` from the loaded Document objects to plain\n",
    "# chunk strings. The plotting cell reads it under this name, so it is kept as is.\n",
    "documents = result['documents']\n",
    "metadatas = result['metadatas']\n",
    "doc_types = [metadata['doc_type'] for metadata in metadatas]\n",
    "\n",
    "# One colour per known doc_type. Anything else falls back to gray instead of\n",
    "# raising ValueError, so an extra knowledge-base folder does not break the plot.\n",
    "DOC_TYPE_COLORS = {'products': 'blue', 'employees': 'green', 'contracts': 'red', 'company': 'orange'}\n",
    "colors = [DOC_TYPE_COLORS.get(t, 'gray') for t in doc_types]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5428164b-f0d5-4d2b-ac4a-514c43ceaa79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We humans find it easier to visualize things in 2D!\n",
    "# Reduce the dimensionality of the vectors to 2D using t-SNE\n",
    "# (t-distributed stochastic neighbor embedding)\n",
    "\n",
    "tsne = TSNE(n_components=2, random_state=42)\n",
    "reduced_vectors = tsne.fit_transform(vectors)\n",
    "\n",
    "# Hover text: the chunk's type plus the first 100 characters of its text\n",
    "hover_text = [f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(doc_types, documents)]\n",
    "\n",
    "# Create the 2D scatter plot\n",
    "fig = go.Figure(data=[go.Scatter(\n",
    "    x=reduced_vectors[:, 0],\n",
    "    y=reduced_vectors[:, 1],\n",
    "    mode='markers',\n",
    "    marker=dict(size=5, color=colors, opacity=0.8),\n",
    "    text=hover_text,\n",
    "    hoverinfo='text'\n",
    ")])\n",
    "\n",
    "# Axis titles are set at the top level of the layout: `scene` only applies to\n",
    "# 3D plots, so titles nested under it never show up on a 2D scatter.\n",
    "fig.update_layout(\n",
    "    title='2D Chroma Vector Store Visualization',\n",
    "    xaxis_title='x',\n",
    "    yaxis_title='y',\n",
    "    width=800,\n",
    "    height=600,\n",
    "    margin=dict(r=20, b=10, l=10, t=40)\n",
    ")\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "359b8651-a382-4050-8bf8-123e5cdf4d53",
   "metadata": {},
   "source": [
    "## RAG Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08a75313-6c68-42e5-bd37-78254123094c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the 20 nearest chunks for every query.\n",
    "retriever = vectorstore.as_retriever(search_kwargs={\"k\": 20})\n",
    "\n",
    "chat_llm = ChatGoogleGenerativeAI(model=CHAT_MODEL, temperature=0.7)\n",
    "\n",
    "# System instruction for the query-rewriting step. It must actually be part of the\n",
    "# prompt: without it the model answers the follow-up instead of producing a\n",
    "# standalone search query, and retrieval then runs on that answer.\n",
    "question_generator_template = \"\"\"Given the conversation so far and a follow up question, rephrase the follow up question to be a standalone question that can be understood without the chat history.\n",
    "If the follow up question is already a standalone question, return it as is.\n",
    "Do NOT answer the question; reply with the standalone question only.\"\"\"\n",
    "\n",
    "# Message order: rewriting instruction, prior turns, then the new user input.\n",
    "question_generator_prompt = ChatPromptTemplate.from_messages([\n",
    "    SystemMessagePromptTemplate.from_template(question_generator_template),\n",
    "    MessagesPlaceholder(variable_name=\"chat_history\"),\n",
    "    HumanMessagePromptTemplate.from_template(\"{input}\")\n",
    "])\n",
    "\n",
    "history_aware_retriever = create_history_aware_retriever(\n",
    "    chat_llm, retriever, question_generator_prompt\n",
    ")\n",
    "\n",
    "# System instructions for the answering step; ends with the {context} placeholder.\n",
    "qa_system_prompt = \"\"\"You are Insurellm’s intelligent virtual assistant, designed to answer questions with accuracy and clarity. Respond naturally and helpfully, as if you're part of the team.\n",
    "Use the retrieved documents and prior conversation to provide accurate, conversational, and concise answers.Rephrase source facts in a natural tone, not word-for-word.\n",
    "When referencing people or company history, prioritize clarity and correctness.\n",
    "Only infer from previous conversation if it provides clear and factual clues. Do not guess or assume missing information.\n",
    "If you truly don’t have the answer, respond with:\n",
    "\"I don't have that information.\"\n",
    "Avoid repeating the user's wording unnecessarily. Do not refer to 'the context', speculate, or make up facts.\n",
    "\n",
    "{context}\"\"\"\n",
    "\n",
    "# The user's turn is passed through verbatim.\n",
    "qa_human_prompt = \"{input}\"\n",
    "\n",
    "# Message order: system instructions, prior turns, then the new question.\n",
    "qa_messages = [\n",
    "    SystemMessagePromptTemplate.from_template(qa_system_prompt),\n",
    "    MessagesPlaceholder(variable_name=\"chat_history\"),\n",
    "    HumanMessagePromptTemplate.from_template(qa_human_prompt),\n",
    "]\n",
    "qa_prompt = ChatPromptTemplate.from_messages(qa_messages)\n",
    "\n",
    "combine_docs_chain = create_stuff_documents_chain(llm=chat_llm, prompt=qa_prompt)\n",
    "\n",
    "# inspect_context = RunnableLambda(lambda inputs: (\n",
    "#     print(\"\\n Retrieved Context:\\n\", \"\\n---\\n\".join([doc.page_content for doc in inputs[\"context\"]])),\n",
    "#     inputs  # pass it through unchanged\n",
    "# )[1])\n",
    "\n",
    "# inspect_inputs = RunnableLambda(lambda inputs: (\n",
    "#     print(\"\\n Inputs received by the chain:\\n\", inputs),\n",
    "#     inputs\n",
    "# )[1])\n",
    "\n",
    "# Retrieval chain: history-aware retrieval feeding the QA chain built from qa_prompt.\n",
    "base_chain = create_retrieval_chain(\n",
    "    retriever=history_aware_retriever,\n",
    "    combine_docs_chain=combine_docs_chain,\n",
    ")\n",
    "\n",
    "# Using Runnable Lambda as Gradio needs the response to contain only the output (answer) and base_chain would have a dict with input, context, chat_history, answer\n",
    "\n",
    "# base_chain_with_output = base_chain | inspect_context | RunnableLambda(lambda res: res[\"answer\"])\n",
    "# base_chain_with_output = base_chain | RunnableLambda(lambda res: res[\"answer\"])\n",
    "\n",
    "\n",
    "# Session Persistent Chat History \n",
    "# If we want to persist history between sessions then use MongoDB (or any non sql DB)to store and use MongoDBChatMessageHistory (relevant DB Wrapper)\n",
    "\n",
    "chat_histories = {}\n",
    "\n",
    "def get_history(session_id):\n",
    "    if session_id not in chat_histories:\n",
    "        chat_histories[session_id] = InMemoryChatMessageHistory()\n",
    "    return chat_histories[session_id]\n",
    "\n",
    "# Wrap the chain with per-session message history looked up through get_history.\n",
    "# Currently set up for streaming: base_chain returns a dict and \"answer\" is the\n",
    "# output key. For a one-shot response, swap in base_chain_with_output and drop\n",
    "# output_messages_key.\n",
    "conversation_chain = RunnableWithMessageHistory(\n",
    "    base_chain,\n",
    "    get_history,\n",
    "    input_messages_key=\"input\",\n",
    "    history_messages_key=\"chat_history\",\n",
    "    output_messages_key=\"answer\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06b58566-70cb-42eb-8b1c-9fe353fe71f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat(question, history):\n",
    "    try:\n",
    "        # result = conversation_chain.invoke({\"input\": question, \"chat_history\": memory.buffer_as_messages})\n",
    "        \n",
    "        # memory.chat_memory.add_user_message(question)\n",
    "        # memory.chat_memory.add_ai_message(result[\"answer\"])\n",
    "\n",
    "        # return result[\"answer\"]\n",
    "\n",
    "        \n",
    "        session_id = \"default-session\"\n",
    "\n",
    "        # # FUll chat version\n",
    "        # result = conversation_chain.invoke(\n",
    "        #     {\"input\": question},\n",
    "        #     config={\"configurable\": {\"session_id\": session_id}}\n",
    "        # )\n",
    "        # # print(result)\n",
    "        # return result\n",
    "\n",
    "        # Streaming Version\n",
    "        response_buffer = \"\"\n",
    "\n",
    "        for chunk in conversation_chain.stream({\"input\": question},config={\"configurable\": {\"session_id\": session_id}}):\n",
    "            if \"answer\" in chunk:\n",
    "                response_buffer += chunk[\"answer\"]\n",
    "                yield response_buffer \n",
    "    except Exception as e:\n",
    "        print(f\"An error occurred during chat: {e}\")\n",
    "        return \"I apologize, but I encountered an error and cannot answer that right now.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a577ac66-3952-4821-83d2-8a50bad89971",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the chat UI around the `chat` generator and open it in a browser tab.\n",
    "chat_ui = gr.ChatInterface(fn=chat, type=\"messages\")\n",
    "view = chat_ui.launch(inbrowser=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56b63a17-2522-46e5-b5a3-e2e80e52a723",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
